{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is intended to demonstrate how select registration, segmentation, and image mathematical methods of ITKTubeTK can be combined to perform multi-channel brain extraction (aka. skull stripping for patient data containing multiple MRI sequences).\n",
    "\n",
    "There are many other (probably more effective) brain extraction methods available as open-source software such as BET and BET2 in the FSL package (albeit such methods are only for single channel data).   If you need to perform brain extraction for a large collection of scans that do not contain major pathologies, please use one of those packages.   This notebook is meant to show off the capabilities of specific ITKTubeTK methods, not to demonstration how to \"solve\" brain extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itk\n",
    "from itk import TubeTK as ttk\n",
    "\n",
    "from itkwidgets import view\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ImageType = itk.Image[itk.F, 3]\n",
    "ReaderType = itk.ImageFileReader[ImageType]\n",
    "\n",
    "reader1 = ReaderType.New(FileName=\"../Data/MRI-Cases/mra.mha\")\n",
    "reader1.Update()\n",
    "im1 = reader1.GetOutput()\n",
    "\n",
    "reader2 = ReaderType.New(FileName=\"../Data/MRI-Cases/mri_t1_sag.mha\")\n",
    "reader2.Update()\n",
    "im2 = reader2.GetOutput()\n",
    "\n",
    "reader3 = ReaderType.New(FileName=\"../Data/MRI-Cases/mri_t2.mha\")\n",
    "reader3.Update()\n",
    "im3 = reader3.GetOutput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resample data into isotropic voxels and register T1 and T2 with MRA\n",
    "\n",
    "res = ttk.ResampleImage.New(Input = im1) \n",
    "res.SetMakeHighResIso(True)\n",
    "res.Update()\n",
    "im1iso = res.GetOutput()\n",
    "\n",
    "imReg = ttk.RegisterImages[ImageType].New()  # This is the standard protocol for within patient within visit registration\n",
    "imReg.SetFixedImage(im1iso)\n",
    "imReg.SetMovingImage(im2)\n",
    "imReg.SetReportProgress(True)\n",
    "imReg.SetExpectedOffsetMagnitude(40)\n",
    "imReg.SetExpectedRotationMagnitude(0.01)\n",
    "imReg.SetExpectedScaleMagnitude(0.01)\n",
    "imReg.SetRigidMaxIterations(500)\n",
    "imReg.SetRigidSamplingRatio(0.1)\n",
    "imReg.SetRegistration(\"RIGID\")\n",
    "imReg.SetMetric(\"MATTES_MI_METRIC\")\n",
    "imReg.Update()\n",
    "im2iso = imReg.GetFinalMovingImage()\n",
    "\n",
    "imReg = ttk.RegisterImages[ImageType].New()\n",
    "imReg.SetFixedImage(im1iso)\n",
    "imReg.SetMovingImage(im3)\n",
    "imReg.SetReportProgress(True)\n",
    "imReg.SetExpectedOffsetMagnitude(40)\n",
    "imReg.SetExpectedRotationMagnitude(0.01)\n",
    "imReg.SetExpectedScaleMagnitude(0.01)\n",
    "imReg.SetRigidMaxIterations(500)\n",
    "imReg.SetRigidSamplingRatio(0.1)\n",
    "imReg.SetRegistration(\"RIGID\")\n",
    "imReg.SetMetric(\"MATTES_MI_METRIC\")\n",
    "imReg.Update()\n",
    "im3iso = imReg.GetFinalMovingImage()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e4c50dcb99a646679ccc41500891b8a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "CheckerboardImagesType = itk.CheckerBoardImageFilter[ImageType]\n",
    "cb23 = CheckerboardImagesType.New(Input1=im2iso, Input2=im3iso)\n",
    "cb23.Update()\n",
    "im23 = ImageType.New()\n",
    "im23 = cb23.GetOutput()\n",
    "view(im23)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 8\n",
    "readerList = [\"003\", \"010\", \"026\", \"034\", \"045\", \"056\", \"063\", \"071\"]\n",
    "\n",
    "imBase = []\n",
    "imBaseB = []\n",
    "for i in range(0,N):\n",
    "    name = \"../data/MRI-Normals/Normal\"+readerList[i]+\"-FLASH.mha\"\n",
    "    nameB = \"../data/MRI-Normals/Normal\"+readerList[i]+\"-FLASH-Brain.mha\"\n",
    "    reader = ReaderType.New(FileName=name)\n",
    "    reader.Update()\n",
    "    imBaseTmp = reader.GetOutput()\n",
    "    reader = ReaderType.New(FileName=nameB)\n",
    "    reader.Update()\n",
    "    imBaseBTmp = reader.GetOutput()\n",
    "    imBase.append(imBaseTmp)\n",
    "    imBaseB.append(imBaseBTmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The history saving thread hit an unexpected error (OperationalError('database is locked')).History will not be written to the database.\n"
     ]
    }
   ],
   "source": [
    "imMath = ttk.ImageMath.New(Input=im2iso)\n",
    "imMath.Blur(2)\n",
    "im2isoBlur = imMath.GetOutput()\n",
    "\n",
    "regB = []\n",
    "regBB = []\n",
    "for i in range(0,N):\n",
    "    imMath.SetInput(imBase[i])\n",
    "    imMath.Blur(2)\n",
    "    imBaseBlur = imMath.GetOutput()\n",
    "    regBTo1 = ttk.RegisterImages[ImageType].New(FixedImage=im2isoBlur, MovingImage=imBaseBlur)\n",
    "    regBTo1.SetReportProgress(True)\n",
    "    regBTo1.SetExpectedOffsetMagnitude(40)\n",
    "    regBTo1.SetExpectedRotationMagnitude(0.01)\n",
    "    regBTo1.SetExpectedScaleMagnitude(0.1)\n",
    "    regBTo1.SetRigidMaxIterations(500)\n",
    "    regBTo1.SetAffineMaxIterations(500)\n",
    "    regBTo1.SetRigidSamplingRatio(0.05)\n",
    "    regBTo1.SetAffineSamplingRatio(0.05)\n",
    "    regBTo1.SetInitialMethodEnum(\"INIT_WITH_IMAGE_CENTERS\")\n",
    "    regBTo1.SetRegistration(\"PIPELINE_AFFINE\")\n",
    "    regBTo1.SetMetric(\"MATTES_MI_METRIC\")\n",
    "    #regBTo1.SetMetric(\"NORMALIZED_CORRELATION_METRIC\") - Really slow!\n",
    "    #regBTo1.SetMetric(\"MEAN_SQUARED_ERROR_METRIC\")\n",
    "    regBTo1.Update()\n",
    "    img = regBTo1.ResampleImage(\"LINEAR\", imBase[i])\n",
    "    regB.append( img )\n",
    "    img = regBTo1.ResampleImage(\"LINEAR\", imBaseB[i])\n",
    "    regBB.append( img )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee71ce031d7542999768eca9b2372aed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "regBBT = []\n",
    "for i in range(0,N):\n",
    "    imMath = ttk.ImageMath[ImageType,ImageType].New( Input=regBB[i] )\n",
    "    imMath.Threshold(0,1,0,1)\n",
    "    img = imMath.GetOutput()\n",
    "    if i==0:\n",
    "        imMathSum = ttk.ImageMath[ImageType,ImageType].New( img )\n",
    "        imMathSum.AddImages( img, 1.0/N, 0 )\n",
    "        sumBBT = imMathSum.GetOutput()\n",
    "    else:\n",
    "        imMathSum = ttk.ImageMath[ImageType,ImageType].New( sumBBT )\n",
    "        imMathSum.AddImages( img, 1, 1.0/N )\n",
    "        sumBBT = imMathSum.GetOutput()\n",
    "        \n",
    "view(sumBBT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d726efd7438f4e86a97f6656f020a612",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "insideMath = ttk.ImageMath[ImageType,ImageType].New( Input = sumBBT )\n",
    "insideMath.Threshold(1,1,1,0)\n",
    "insideMath.Erode(5,1,0)\n",
    "brainInside = insideMath.GetOutput()\n",
    "\n",
    "outsideMath = ttk.ImageMath[ImageType,ImageType].New( Input = sumBBT )\n",
    "outsideMath.Threshold(0,0,1,0)\n",
    "outsideMath.Erode(10,1,0)\n",
    "brainOutsideAll = outsideMath.GetOutput()\n",
    "outsideMath.Erode(20,1,0)\n",
    "outsideMath.AddImages(brainOutsideAll, -1, 1)\n",
    "brainOutside = outsideMath.GetOutput()\n",
    "\n",
    "outsideMath.AddImages(brainInside,1,2)\n",
    "brainCombinedMask = outsideMath.GetOutputUChar()\n",
    "\n",
    "outsideMath.AddImages(im2iso, 512, 1)\n",
    "brainCombinedMaskView = outsideMath.GetOutput()\n",
    "view(brainCombinedMaskView)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "LabelMapType = itk.Image[itk.UC,3]\n",
    "\n",
    "segmenter = ttk.SegmentConnectedComponentsUsingParzenPDFs[ImageType,LabelMapType].New()\n",
    "segmenter.SetFeatureImage( im1iso )\n",
    "segmenter.AddFeatureImage( im2iso )\n",
    "segmenter.AddFeatureImage( im3iso )\n",
    "segmenter.SetInputLabelMap( brainCombinedMask )\n",
    "segmenter.SetObjectId( 2 )\n",
    "segmenter.AddObjectId( 1 )\n",
    "segmenter.SetVoidId( 0 )\n",
    "segmenter.SetErodeDilateRadius( 5 )\n",
    "segmenter.Update()\n",
    "segmenter.ClassifyImages()\n",
    "brainCombinedMaskClassified = segmenter.GetOutputLabelMap()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c5c347788b544e787c351b2f198eddb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "cast = itk.CastImageFilter[LabelMapType, ImageType].New()\n",
    "cast.SetInput(brainCombinedMaskClassified)\n",
    "cast.Update()\n",
    "brainMaskF = cast.GetOutput()\n",
    "\n",
    "brainMath = ttk.ImageMath[ImageType,ImageType].New(Input = brainMaskF)\n",
    "brainMath.Threshold(2,2,1,0)\n",
    "brainMath.Dilate(2,1,0)\n",
    "brainMaskD = brainMath.GetOutput()\n",
    "brainMath.SetInput( im1 )\n",
    "brainMath.ReplaceValuesOutsideMaskRange( brainMaskD, 1, 1, 0)\n",
    "brain = brainMath.GetOutput()\n",
    "\n",
    "view(brain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "itk.imwrite(brainMaskD, \"MRMask-Brain.mha\")\n",
    "\n",
    "itk.imwrite(im1iso, \"MRA-Iso.mha\")\n",
    "itk.imwrite(im2iso, \"MRT1-Iso.mha\")\n",
    "itk.imwrite(im3iso, \"MRT2-Iso.mha\")\n",
    "\n",
    "itk.imwrite(brain, \"MRA-Brain.mha\")\n",
    "\n",
    "brainMath.SetInput(im2iso)\n",
    "brainMath.ReplaceValuesOutsideMaskRange( brainMaskD, 1, 1, 0 )\n",
    "itk.imwrite(brainMath.GetOutput(), \"MRT1-Brain.mha\")\n",
    "\n",
    "brainMath.SetInput(im3iso)\n",
    "brainMath.ReplaceValuesOutsideMaskRange( brainMaskD, 1, 1, 0 )\n",
    "itk.imwrite(brainMath.GetOutput(), \"MRT2-Brain.mha\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
